"""Utils for loading data."""

import os
from numpy.lib.function_base import select

import torch
import numpy as np
from torch.utils import data

from torchvision.datasets import CIFAR10
from torchvision.datasets import CIFAR100
from torchvision.datasets import MNIST
from torchvision.datasets import SVHN
from torch.utils.data import Dataset, TensorDataset, DataLoader
import torchvision.transforms as transforms


# Both roots default to None and are passed straight to os.path.join by the
# loaders below, so they must be set to real directories before any loader
# that reads partitioned or custom data is built.
PARTITION_ROOT = None  # path to partitioned dataset
CUSTOM_DATA_ROOT = None  # path to custom dataset


class CustomTensorDataset(Dataset):
    """Tensor-backed dataset that optionally transforms each input sample.

    ``tensors[0]`` holds the inputs and ``tensors[1]`` the targets. All
    tensors must have the same size along dimension 0.
    """

    def __init__(self, tensors, transform=None):
        # Every tensor must provide one entry per sample.
        assert all(t.size(0) == tensors[0].size(0) for t in tensors)
        self.transform = transform
        self.tensors = tensors

    def __len__(self):
        inputs = self.tensors[0]
        return inputs.size(0)

    def __getitem__(self, index):
        sample = self.tensors[0][index]

        # The transform is applied to the input only, never to the target.
        transform = self.transform
        if transform:
            sample = transform(sample)

        # Targets are always returned as integer (long) tensors.
        target = self.tensors[1][index].long()
        return sample, target


def get_dataloader(cfg):
    """Build the data loader(s) selected by ``cfg["data"]["type"]``.

    The shape of the return value depends on the experiment type:

    * ``"regular"``: ``[(train_loader, test_loader)]``.
    * ``"corrupt"``: a single shuffled loader built from
      ``load_cifar10_image``.
    * ``"same_data"``: two independently built regular pairs,
      ``[(train, test), (train, test)]``.
    * ``"split_subset"``: ``[[train_0, test], [train_1, test]]`` where each
      train loader covers a random subset of one partition.
    * ``"same_subset"``: like ``"split_subset"``, but the first pair is
      replaced by the second so both entries are the same loaders.
    * ``"split_data"``: ``[[train_0, test], [train_1, test]]`` over the full
      partitions.
    * ``"random_subset"``: ``(train_loader, test_loader, chosen_idx)``.
    * ``"fixed_class"``: ``(train_loader, test_loader, chosen_idx)``.
    * ``"fixed_class_v2"``: ``[(train, test), (train, test)]`` (same loaders
      twice).

    Args:
        cfg: mapping with a ``"data"`` entry; ``"type"`` and ``"name"`` are
            always required, other keys depend on the experiment type.

    Raises:
        ValueError: if the experiment type (or, for partition-based types,
            the dataset name) is not recognized.
    """
    # NOTE(review): this banner is printed for every experiment type, not
    # only for "regular"; kept as-is so existing logs do not change.
    print("** regular data loader")
    dataset_spec = cfg["data"]
    exp_type = dataset_spec["type"]
    dataset_name = dataset_spec["name"]
    if exp_type == "regular":
        return _get_regular_dataloaders(cfg)
    elif exp_type == "corrupt":
        return _get_corrupt_dataloader(dataset_spec)
    elif exp_type == "same_data":
        loaders_1 = _get_regular_dataloaders(cfg)
        loaders_2 = _get_regular_dataloaders(cfg)
        return loaders_1 + loaders_2
    elif exp_type == "split_subset":
        return _get_split_subset_dataloaders(cfg)
    elif exp_type == "same_subset":
        all_data_loaders = _get_split_subset_dataloaders(cfg)
        all_data_loaders[0] = all_data_loaders[1]
        return all_data_loaders
    elif exp_type == "split_data":
        print("Getting split data")
        return _get_split_data_dataloaders(dataset_spec)
    elif exp_type == "random_subset":
        print("Getting ramdom subset of {}".format(dataset_name))
        return _get_random_subset_dataloaders(dataset_spec)
    elif exp_type == "fixed_class":
        return _get_fixed_class_dataloaders(cfg, return_idx=True)
    elif exp_type == "fixed_class_v2":
        return _get_fixed_class_dataloaders(cfg, return_idx=False)
    else:
        raise ValueError("Unrecognized experiment type: {}".format(exp_type))


def _strip_partition_index(path, i):
    """Drop the ``_<i>`` partition suffix from the file name of ``path``.

    Only the final path component is rewritten, so a directory that happens
    to contain ``_<i>`` (e.g. ``runs_0/``) is left untouched.
    """
    directory, filename = os.path.split(path)
    return os.path.join(directory, filename.replace("_{}".format(i), ""))


def _merge_into_super_classes(label_np, target_n_class):
    """Map labels onto ``target_n_class`` contiguous groups by integer division."""
    n_class = np.max(label_np) + 1
    n_per_super_class = n_class // target_n_class
    return label_np // n_per_super_class


def _channel_two_fraction(images_np):
    """Return, per image, the share of total intensity found at index 2 of axis 1.

    Presumably this is the blue channel of NCHW RGB data -- TODO confirm.
    """
    return np.sum(images_np[:, 2, :, :], axis=(1, 2)) / np.sum(
        images_np, axis=(1, 2, 3)
    )


def _get_corrupt_dataloader(dataset_spec):
    """Return one shuffled loader over the dataset built by ``load_cifar10_image``."""
    print("getting corrupt data loader")
    dataset = load_cifar10_image(
        corruption_type=dataset_spec['corruption_type'],
        clean_cifar_path='./regular_data',
        corruption_cifar_path='/project_data/datasets/CIFAR-10-C',
        corruption_severity=dataset_spec['corruption_severity'],
        normalize=dataset_spec['normalize'],
    )
    return DataLoader(
        dataset, batch_size=dataset_spec["batch_size"], shuffle=True, num_workers=2
    )


def _get_split_data_dataloaders(dataset_spec):
    """Return ``[[train_0, test], [train_1, test]]`` over the two full partitions."""
    (
        base_image_path,
        base_label_path,
        tr,
    ) = _get_base_image_path_and_augmentation_transformation(dataset_spec)

    all_data_loaders = []
    for i in range(2):
        data_loaders_pair = []
        for mode in ["train", "test"]:
            image_path = base_image_path.format(mode, i)
            label_path = base_label_path.format(mode, i)
            if mode == "test":
                # The test files are shared by both partitions and carry no index.
                image_path = _strip_partition_index(image_path, i)
                label_path = _strip_partition_index(label_path, i)
            print(image_path)
            print(label_path)
            label_np = np.load(label_path)
            if 'merge_class' in dataset_spec:
                label_np = _merge_into_super_classes(
                    label_np, dataset_spec['merge_class']
                )
            tensor_x = torch.Tensor(np.load(image_path))
            tensor_y = torch.Tensor(label_np)
            dataset = CustomTensorDataset(
                tensors=(tensor_x, tensor_y), transform=tr[mode]
            )
            # NOTE(review): the test loader is shuffled here as well, unlike
            # the other experiment types; kept for backward compatibility.
            dataloader = DataLoader(
                dataset,
                batch_size=dataset_spec["batch_size"],
                shuffle=True,
                num_workers=2,
            )
            data_loaders_pair.append(dataloader)
        all_data_loaders.append(data_loaders_pair)
    return all_data_loaders


def _get_random_subset_dataloaders(dataset_spec):
    """Return ``(train_loader, test_loader, chosen_idx)`` for a random training subset."""
    (
        base_image_path,
        base_label_path,
        tr,
    ) = _get_base_image_path_and_augmentation_transformation(dataset_spec)

    test_image_path = _strip_partition_index(base_image_path.format("test", 0), 0)
    test_label_path = _strip_partition_index(base_label_path.format("test", 0), 0)

    # The training pool is the concatenation of both partitions.
    images_np = np.concatenate(
        [np.load(base_image_path.format("train", i)) for i in range(2)]
    )
    label_np = np.concatenate(
        [np.load(base_label_path.format("train", i)) for i in range(2)]
    )

    if 'merge_class' in dataset_spec:
        target_n_class = dataset_spec['merge_class']
        label_np = _merge_into_super_classes(label_np, target_n_class)
        print(f'Merging into {target_n_class} super-classes')
        print(label_np[:10])

    if 'color_group' in dataset_spec:
        boundary_idx = dataset_spec["color_group"]
        # Six percentile boundaries delimit five groups (indices 0..4).
        assert boundary_idx < 5
        score = _channel_two_fraction(images_np)
        group_boundary = np.percentile(score, (1, 20, 40, 60, 80, 99))
        lower = group_boundary[boundary_idx]
        upper = group_boundary[boundary_idx + 1]
        selected_idx = np.where(np.logical_and(score > lower, score < upper))[0]
        assert dataset_spec["subset_size"] < len(selected_idx)
        images_np = images_np[selected_idx]
        label_np = label_np[selected_idx]
        print(f'Train size {len(selected_idx)}')

    # Always draw a subset first so the global NumPy RNG advances the same
    # way whether or not a stored index file is used afterwards.
    chosen_idx = np.random.choice(
        np.arange(len(images_np)), dataset_spec["subset_size"], replace=False
    )
    if (
        "fixed_subset_index_path" in dataset_spec
        and dataset_spec["fixed_subset_index_path"] is not None
    ):
        index_path = dataset_spec["fixed_subset_index_path"]
        print("Using fixed subset from: {}".format(index_path))
        if os.path.exists(index_path):
            chosen_idx = np.load(index_path)
        else:
            print("Saving the index because the provide file does not exist...")
            np.save(index_path, chosen_idx)

    if (
        dataset_spec["name"] == "CIFAR-10"
        and "random_label" in dataset_spec
        and dataset_spec["random_label"] == True
    ):
        # NOTE(review): the loaded labels replace label_np wholesale and are
        # then indexed with chosen_idx; presumably they align with the full
        # concatenated training set -- confirm before combining this option
        # with 'color_group', which re-indexes the images.
        random_label_path = os.path.join(
            CUSTOM_DATA_ROOT,
            "CIFAR10",
            "random_binary_label.npy",
        )
        label_np = np.load(random_label_path)
        print('Using random binary label!')

    train_tensor_x = torch.Tensor(images_np[chosen_idx])
    train_tensor_y = torch.Tensor(label_np[chosen_idx])

    test_label_np = np.load(test_label_path)
    test_data_np = np.load(test_image_path)

    if 'merge_class' in dataset_spec:
        test_label_np = _merge_into_super_classes(
            test_label_np, dataset_spec['merge_class']
        )

    if 'color_group' in dataset_spec:
        # The test set is filtered with the boundaries computed on the
        # training images, not with percentiles of its own.
        score = _channel_two_fraction(test_data_np)
        selected_idx = np.where(np.logical_and(score > lower, score < upper))[0]
        test_data_np = test_data_np[selected_idx]
        test_label_np = test_label_np[selected_idx]
        print(f'Test size {len(selected_idx)}')

    test_tensor_x = torch.Tensor(test_data_np)
    test_tensor_y = torch.Tensor(test_label_np)

    train_dataloader = DataLoader(
        CustomTensorDataset(
            tensors=(train_tensor_x, train_tensor_y), transform=tr["train"]
        ),
        batch_size=dataset_spec["batch_size"],
        shuffle=True,
        num_workers=2,
    )
    test_dataloader = DataLoader(
        CustomTensorDataset(
            tensors=(test_tensor_x, test_tensor_y), transform=tr["test"]
        ),
        batch_size=dataset_spec["batch_size"] * 4,
        shuffle=False,
        num_workers=2,
    )
    return train_dataloader, test_dataloader, chosen_idx


def _get_regular_dataloaders(cfg):
    dataset_spec = cfg["data"]
    exp_type = dataset_spec["type"]
    dataset_name = dataset_spec["name"]
    print(dataset_name)
    if dataset_name == "CIFAR-10":
        if dataset_spec["data_augmentation"]:
            tr = cifar10_transform(add_pil=False)
        else:
            tr = {
                "train": transforms.Compose([transforms.ToTensor()]),
                "test": transforms.Compose([transforms.ToTensor()]),
            }
        trainset = CIFAR10(
            root="./regular_data", train=True, download=True, transform=tr["train"],
        )
        testset = CIFAR10(
            root="./regular_data", train=False, download=True, transform=tr["test"],
        )
    elif dataset_name == "CIFAR-10-Softlabel":
        if dataset_spec["data_augmentation"]:
            tr = cifar10_transform(add_pil=True)
        else:
            tr = {
                "train": None,
                "test": transforms.Compose([transforms.ToTensor()]),
            }
        image_path = os.path.join(CUSTOM_DATA_ROOT, "CIFAR10", "Ensemble", "image.npy")
        label_path = os.path.join(CUSTOM_DATA_ROOT, "CIFAR10", "Ensemble", "label.npy")
        images_np, label_np = np.load(image_path) / 255.0, np.load(label_path)
        train_tensor_x, train_tensor_y = torch.Tensor(images_np), torch.Tensor(label_np)
        trainset = CustomTensorDataset(
            tensors=(train_tensor_x, train_tensor_y), transform=tr["train"]
        )
        testset = CIFAR10(
            root="./regular_data", train=False, download=True, transform=tr["test"],
        )
    elif dataset_name == "CIFAR-100":
        if dataset_spec["data_augmentation"]:
            tr = cifar10_transform(add_pil=False)  # use same aug as cifar10
        else:
            tr = {
                "train": transforms.Compose([transforms.ToTensor()]),
                "test": transforms.Compose([transforms.ToTensor()]),
            }
        trainset = CIFAR100(
            root="./regular_data", train=True, download=True, transform=tr["train"],
        )
        testset = CIFAR100(
            root="./regular_data", train=False, download=True, transform=tr["test"],
        )
    elif dataset_name == "MNIST":
        if dataset_spec["data_augmentation"]:
            tr = mnist_transform()
        else:
            tr = {
                "train": transforms.Compose([transforms.ToTensor()]),
                "test": transforms.Compose([transforms.ToTensor()]),
            }
        trainset = MNIST(
            root="./regular_data", train=True, download=True, transform=tr["train"],
        )
        testset = MNIST(
            root="./regular_data", train=False, download=True, transform=tr["test"],
        )
    elif dataset_name == "SVHN":
        if dataset_spec["data_augmentation"]:
            tr = svhn_transform(add_pil=True)
        else:
            tr = {
                "train": None,
                "test": None,
            }
        # trainset = SVHN(
        #     root="./regular_data", split="train", download=True, transform=tr["train"],
        # )
        # testset = SVHN(
        #     root="./regular_data", split="test", download=True, transform=tr["test"],
        # )
        base_image_path = os.path.join(PARTITION_ROOT, "SVHN", "{}_image_{}.npy")
        base_label_path = os.path.join(PARTITION_ROOT, "SVHN", "{}_label_{}.npy")
        image_path_0 = base_image_path.format("train", 0)
        image_path_1 = base_image_path.format("train", 1)
        label_path_0 = base_label_path.format("train", 0)
        label_path_1 = base_label_path.format("train", 1)
        test_image_path = base_image_path.format("test", 0).replace("_{}".format(0), "")
        test_label_path = base_label_path.format("test", 0).replace("_{}".format(0), "")

        images_np = np.concatenate([np.load(image_path_0), np.load(image_path_1)])
        label_np = np.concatenate([np.load(label_path_0), np.load(label_path_1)])
        train_tensor_x = torch.Tensor(images_np)
        train_tensor_y = torch.Tensor(label_np)
        test_tensor_x = torch.Tensor(np.load(test_image_path))
        test_tensor_y = torch.Tensor(np.load(test_label_path))
        trainset = CustomTensorDataset(
            tensors=(train_tensor_x, train_tensor_y), transform=tr["train"]
        )
        testset = CustomTensorDataset(
            tensors=(test_tensor_x, test_tensor_y), transform=tr["test"]
        )
    else:
        raise ValueError("Unrecognized dataset: {}".format(dataset_name))
    train_loader = torch.utils.data.DataLoader(
        trainset, batch_size=dataset_spec["batch_size"], shuffle=True, num_workers=2
    )
    test_loader = torch.utils.data.DataLoader(
        testset, batch_size=dataset_spec["batch_size"] * 4, shuffle=False, num_workers=2,
    )
    return [(train_loader, test_loader)]


def _get_fixed_class_dataloaders(cfg, return_idx=False):
    dataset_spec = cfg["data"]
    exp_type = dataset_spec["type"]
    assert "fixed_class" in exp_type
    dataset_name = dataset_spec["name"]
    class_idx = dataset_spec["classes"]
    print("Fixed class dataloader of {} class {}".format(dataset_name, class_idx))

    (
        base_image_path,
        base_label_path,
        tr,
    ) = _get_base_image_path_and_augmentation_transformation(dataset_spec)

    data_loaders_pair = []
    image_path_0 = base_image_path.format("train", 0)
    image_path_1 = base_image_path.format("train", 1)
    label_path_0 = base_label_path.format("train", 0)
    label_path_1 = base_label_path.format("train", 1)
    test_image_path = base_image_path.format("test", 0).replace("_{}".format(0), "")
    test_label_path = base_label_path.format("test", 0).replace("_{}".format(0), "")

    images_np = np.concatenate([np.load(image_path_0), np.load(image_path_1)])
    label_np = np.concatenate([np.load(label_path_0), np.load(label_path_1)])
    test_image_np = np.load(test_image_path)
    test_label_np = np.load(test_label_path)
    chosen_idx = _get_classes_index(label_np, class_idx)
    test_chosen_idx = _get_classes_index(test_label_np, class_idx)

    size = min(dataset_spec["subset_size"], len(chosen_idx))
    chosen_idx = np.random.choice(chosen_idx, size, replace=False)
    if "fixed_subset_index_path" in dataset_spec and os.path.exists(
        dataset_spec["fixed_subset_index_path"]
    ):
        print(
            "Loading dataset from {}.".format(dataset_spec["fixed_subset_index_path"])
        )
        chosen_idx = np.load(dataset_spec["fixed_subset_index_path"])
    elif "fixed_subset_index_path" in dataset_spec and not os.path.exists(
        dataset_spec["fixed_subset_index_path"]
    ):
        print("Creating new dataset subset indices.")
        np.save(dataset_spec["fixed_subset_index_path"], chosen_idx)

    train_tensor_x = torch.Tensor(images_np[chosen_idx])
    train_tensor_y = torch.Tensor(label_np[chosen_idx])
    test_tensor_x = torch.Tensor(test_image_np[test_chosen_idx])
    test_tensor_y = torch.Tensor(test_label_np[test_chosen_idx])

    dataset = CustomTensorDataset(
        tensors=(train_tensor_x, train_tensor_y), transform=tr["train"]
    )

    shuffle_train = True
    if "shuffle_train" in dataset_spec and not dataset_spec["shuffle_train"]:
        print("Not shuffling the data.")
        shuffle_train = False

    train_dataloader = DataLoader(
        dataset,
        batch_size=dataset_spec["batch_size"],
        shuffle=shuffle_train,
        num_workers=2,
    )
    dataset = CustomTensorDataset(
        tensors=(test_tensor_x, test_tensor_y), transform=tr["test"]
    )
    test_dataloader = DataLoader(
        dataset,
        batch_size=dataset_spec["batch_size"] * 4,
        shuffle=False,
        num_workers=2,
    )
    # return train_dataloader, test_dataloader, chosen_idx
    if not return_idx:
        return [
            (train_dataloader, test_dataloader),
            (train_dataloader, test_dataloader),
        ]
    else:
        return train_dataloader, test_dataloader, chosen_idx


def _get_split_subset_dataloaders(cfg, return_idx=False):
    """Build ``[[train_0, test], [train_1, test]]`` with subsampled training sets.

    For each of the two partitions, ``subset_size`` training samples are
    drawn at random. If ``"fixed_subset_index_path_<i>"`` is configured
    (present and not ``None``), the indices are loaded from that file when it
    exists and written to it otherwise. The shared test set is used in full.

    Args:
        cfg: mapping whose ``"data"`` entry provides ``"name"``,
            ``"subset_size"`` and ``"batch_size"``.
        return_idx: accepted for interface compatibility; not used.

    Returns:
        A list with one ``[train_loader, test_loader]`` list per partition.
    """
    dataset_spec = cfg["data"]
    dataset_name = dataset_spec["name"]
    print("split subset for: ", dataset_name)
    (
        base_image_path,
        base_label_path,
        tr,
    ) = _get_base_image_path_and_augmentation_transformation(dataset_spec)
    all_data_loaders = []
    for i in range(2):
        data_loaders_pair = []
        for mode in ["train", "test"]:
            image_path = base_image_path.format(mode, i)
            label_path = base_label_path.format(mode, i)
            if mode == "test":
                # The shared test files carry no partition index. Strip the
                # suffix from the file name only, so a directory containing
                # "_<i>" is not rewritten.
                image_dir, image_name = os.path.split(image_path)
                label_dir, label_name = os.path.split(label_path)
                image_path = os.path.join(
                    image_dir, image_name.replace("_{}".format(i), "")
                )
                label_path = os.path.join(
                    label_dir, label_name.replace("_{}".format(i), "")
                )

            image_np, label_np = np.load(image_path), np.load(label_path)
            chosen_idx = np.arange(len(image_np))

            if mode == "train":
                # Always draw first so the global NumPy RNG advances the same
                # way whether or not a stored index file replaces the draw.
                chosen_idx = np.random.choice(
                    chosen_idx, dataset_spec["subset_size"], replace=False
                )
                print("Using a subset of size: {}".format(len(chosen_idx)))

                index_key = "fixed_subset_index_path_{}".format(i)
                index_file = None
                if index_key in dataset_spec:
                    index_file = dataset_spec[index_key]
                if index_file is not None:
                    if os.path.exists(index_file):
                        print(
                            "Loading subset using index from {}.".format(index_file)
                        )
                        chosen_idx = np.load(index_file)
                    else:
                        print(
                            "Creating new dataset subset indices at {}.".format(
                                index_file
                            )
                        )
                        assert dataset_spec["subset_size"] < len(
                            image_np
                        ), "subset size larger than total number of image"
                        # Save to the configured path (the value), not to the
                        # config key name, so the next run finds the file.
                        # NOTE: np.save appends ".npy" when the name lacks it;
                        # configure a path ending in ".npy".
                        np.save(index_file, chosen_idx)

            image_np, label_np = image_np[chosen_idx], label_np[chosen_idx]
            tensor_x, tensor_y = torch.Tensor(image_np), torch.Tensor(label_np)
            dataset = CustomTensorDataset(
                tensors=(tensor_x, tensor_y), transform=tr[mode]
            )
            # NOTE(review): the test loader is shuffled here as well; kept
            # for backward compatibility.
            dataloader = DataLoader(
                dataset,
                batch_size=dataset_spec["batch_size"],
                shuffle=True,
                num_workers=2,
            )
            data_loaders_pair.append(dataloader)
        all_data_loaders.append(data_loaders_pair)
    return all_data_loaders


def _get_base_image_path_and_augmentation_transformation(dataset_spec):
    dataset_name = dataset_spec["name"]
    if dataset_name == "CIFAR-10":
        if "data_augmentation" in dataset_spec and dataset_spec["data_augmentation"]:
            tr = cifar10_transform(add_pil=True)
        else:
            tr = {
                "train": None,
                "test": None,
            }
        base_image_path = os.path.join(PARTITION_ROOT, "CIFAR10", "{}_image_{}.npy")
        base_label_path = os.path.join(PARTITION_ROOT, "CIFAR10", "{}_label_{}.npy")
    elif dataset_name == "CIFAR-100":
        if "data_augmentation" in dataset_spec and dataset_spec["data_augmentation"]:
            tr = cifar10_transform(
                add_pil=True
            )  # share the same augmentation as cifar10
        else:
            tr = {
                "train": None,
                "test": None,
            }
        base_image_path = os.path.join(PARTITION_ROOT, "CIFAR100", "{}_image_{}.npy")
        base_label_path = os.path.join(PARTITION_ROOT, "CIFAR100", "{}_label_{}.npy")
    elif dataset_name == "SVHN":
        if "data_augmentation" in dataset_spec and dataset_spec["data_augmentation"]:
            tr = svhn_transform(add_pil=True)
        else:
            tr = {
                "train": None,
                "test": None,
            }
        base_image_path = os.path.join(PARTITION_ROOT, "SVHN", "{}_image_{}.npy")
        base_label_path = os.path.join(PARTITION_ROOT, "SVHN", "{}_label_{}.npy")
    elif dataset_name == "MNIST":
        if dataset_spec["data_augmentation"]:
            tr = mnist_transform(add_pil=True)
        else:
            tr = {
                "train": None,
                "test": None,
            }
        base_image_path = os.path.join(PARTITION_ROOT, "MNIST", "{}_image_{}.npy")
        base_label_path = os.path.join(PARTITION_ROOT, "MNIST", "{}_label_{}.npy")
    else:
        raise ValueError("Unrecognized dataset: {}".format(dataset_name))
    return base_image_path, base_label_path, tr


def _get_classes_index(label, class_idx):
    all_idx = np.arange(len(label))
    selected_idx = []
    for c in class_idx:
        selected_idx.append(all_idx[np.where(label == c)])
    chosen_idx = np.concatenate(selected_idx)
    return chosen_idx


###########################################################################
########################### Data augmentation #############################
###########################################################################


def cifar10_transform(add_pil=False):
    """Return the CIFAR-10 ``{"train": ..., "test": ...}`` transform pipelines.

    Training applies a padded random crop and a random horizontal flip before
    tensor conversion and normalization; testing only converts and
    normalizes. With ``add_pil=True`` both pipelines start with
    ``ToPILImage``.
    """
    mean = (0.4914, 0.4822, 0.4465)
    std = (0.2023, 0.1994, 0.2010)

    train_ops = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]
    test_ops = [
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]
    if add_pil:
        train_ops.insert(0, transforms.ToPILImage())
        test_ops.insert(0, transforms.ToPILImage())

    return {
        "train": transforms.Compose(train_ops),
        "test": transforms.Compose(test_ops),
    }


def mnist_transform(add_pil=False):
    """Return the MNIST ``{"train": ..., "test": ...}`` transform pipelines.

    Training applies a padded random crop before tensor conversion and
    normalization; testing only converts and normalizes.

    Args:
        add_pil: when true, both pipelines start with ``ToPILImage`` so they
            can be applied to tensor samples. The test pipeline honours this
            flag too, as the CIFAR-10 and SVHN transforms in this module do;
            without it ``ToTensor`` would be handed a tensor, which it
            rejects.
    """
    ops = [transforms.ToPILImage()] if add_pil else []
    ops += [
        transforms.RandomCrop(28, padding=4),
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
    transform_train = transforms.Compose(ops)
    ops = [transforms.ToPILImage()] if add_pil else []
    ops += [
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
    transform_test = transforms.Compose(ops)
    return {"train": transform_train, "test": transform_test}


def svhn_transform(add_pil=False):
    """Return the SVHN ``{"train": ..., "test": ...}`` transform pipelines.

    Training applies a padded random crop before tensor conversion; testing
    only converts to a tensor. No normalization is applied. With
    ``add_pil=True`` both pipelines start with ``ToPILImage``.
    """
    train_ops = [
        transforms.RandomCrop(32, padding=4),
        transforms.ToTensor(),
    ]
    test_ops = [transforms.ToTensor()]
    if add_pil:
        train_ops.insert(0, transforms.ToPILImage())
        test_ops.insert(0, transforms.ToPILImage())

    return {
        "train": transforms.Compose(train_ops),
        "test": transforms.Compose(test_ops),
    }


###################################################################
########################### Cifar10-C #############################
###################################################################


def load_cifar10_image(corruption_type,
                       clean_cifar_path,
                       corruption_cifar_path,
                       corruption_severity=0,
                       datatype='test',
                       num_samples=50000,
                       normalize=True,
                       seed=1):
    """Load clean CIFAR-10 or one severity level of a corrupted variant.

    Available types:
    [spatter, jpeg_compression, motion_blur, impulse_noise, saturate]

    Args:
        corruption_type: base name of the ``.npy`` image file inside
            ``corruption_cifar_path``; only read when
            ``corruption_severity > 0``.
        clean_cifar_path: root directory for the torchvision CIFAR-10
            download.
        corruption_cifar_path: directory holding ``<corruption_type>.npy``
            and ``labels.npy``.
        corruption_severity: ``0`` keeps the clean data; a positive integer
            ``s`` selects rows ``(s - 1) * 10000`` to ``s * 10000`` of the
            corrupted arrays. Only allowed with ``datatype='test'``.
        datatype: ``'train'`` or ``'test'``.
        num_samples: if smaller than the dataset, a random subset of this
            size is returned.
        normalize: whether to append ``Normalize`` after ``ToTensor``.
        seed: seed passed to ``torch.manual_seed``. Note that this reseeds
            the global torch RNG as a side effect.

    Returns:
        pytorch dataset object (a ``CIFAR10`` instance, or a ``Subset`` of it
        when subsampling applies).
    """
    assert datatype == 'test' or datatype == 'train'
    training_flag = datatype == 'train'

    # NOTE(review): these statistics differ from the ones used by
    # cifar10_transform in this module -- confirm this is intentional.
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    steps = [transforms.ToTensor()]
    if normalize:
        steps.append(transforms.Normalize(mean, std))
    transform = transforms.Compose(steps)

    dataset = CIFAR10(clean_cifar_path,
                      train=training_flag,
                      transform=transform,
                      download=True)

    if corruption_severity > 0:
        assert not training_flag
        assert isinstance(corruption_severity, int), "corruption severity must be an integer"
        path_images = os.path.join(corruption_cifar_path, corruption_type + '.npy')
        path_labels = os.path.join(corruption_cifar_path, 'labels.npy')
        # Each severity level occupies one consecutive block of 10000 rows.
        start = (corruption_severity - 1) * 10000
        stop = corruption_severity * 10000
        dataset.data = np.load(path_images)[start:stop]
        dataset.targets = [int(item) for item in np.load(path_labels)[start:stop]]

    # randomly permute data (images and targets with the same permutation)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    number_samples = dataset.data.shape[0]
    index_permute = torch.randperm(number_samples).numpy()
    dataset.data = dataset.data[index_permute]
    targets = np.array([int(item) for item in dataset.targets])
    dataset.targets = targets[index_permute].tolist()

    # randomly subsample data; the bound is the actual dataset size rather
    # than a hard-coded 50000 / 10000
    if num_samples < number_samples:
        indices = torch.randperm(number_samples)[:num_samples]
        dataset = torch.utils.data.Subset(dataset, indices)
        if training_flag:
            print('number of training data: ', len(dataset))
        else:
            print('number of test data: ', len(dataset))

    return dataset
